{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![MLU Logo](../data/MLU_Logo.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# <a name=\"0\">Machine Learning Accelerator - Tabular Data - Lecture 3</a>\n",
    "\n",
    "\n",
    "## AutoGluon\n",
    "\n",
    "In this notebook, we use __AutoGluon__ to predict the __Outcome Type__ field of our review dataset.\n",
    "\n",
    "\n",
    "[AutoGluon](https://auto.gluon.ai/stable/index.html) implements many of the best practices that we have discussed in this class, and more! In particular, it sets itself apart from other AutoML solutions with excellent automated feature engineering that can handle text data and missing values without any hand-coded solutions (see their [paper](https://arxiv.org/abs/2003.06505) for details). It is too new to be included in the existing SageMaker kernels, so let's install it.\n",
    "\n",
    "1. <a href=\"#1\">Set up AutoGluon</a>\n",
    "2. <a href=\"#2\">Read the dataset</a>\n",
    "3. <a href=\"#3\">Train a classifier with AutoGluon</a>\n",
    "4. <a href=\"#4\">Model evaluation</a>\n",
    "5. <a href=\"#5\">Clean up model artifacts</a>\n",
    "\n",
    "__Austin Animal Center Dataset__:\n",
    "\n",
    "In this exercise, we are working with pet adoption data from __Austin Animal Center__. We have two datasets that cover the intake and outcome of animals. The intake data is available [here](https://data.austintexas.gov/Health-and-Community-Services/Austin-Animal-Center-Intakes/wter-evkm) and the outcome data [here](https://data.austintexas.gov/Health-and-Community-Services/Austin-Animal-Center-Outcomes/9t4d-g238).\n",
    "\n",
    "To work with a single table, we joined the intake and outcome tables on the \"Animal ID\" column and saved the result as a single __review_dataset.csv__ file. To keep the dataset simple, we also excluded animals with multiple entries to the facility. If you want to see the original datasets and the merged data with multiple entries, they are available in the data/review folder: Austin_Animal_Center_Intakes.csv, Austin_Animal_Center_Outcomes.csv and Austin_Animal_Center_Intakes_Outcomes.csv.\n",
    "\n",
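    "The join described above can be sketched in pandas. This is a minimal illustration, not the script that produced the dataset: the two toy tables and their values are hypothetical, and we assume that an animal with multiple entries is one whose `Animal ID` appears more than once in either table.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy stand-ins for the intake and outcome tables\n",
    "intakes = pd.DataFrame({\n",
    "    'Animal ID': ['A1', 'A2', 'A2', 'A3'],\n",
    "    'Intake Type': ['Stray', 'Owner Surrender', 'Stray', 'Stray'],\n",
    "})\n",
    "outcomes = pd.DataFrame({\n",
    "    'Animal ID': ['A1', 'A2', 'A3', 'A4'],\n",
    "    'Outcome Type': ['Adoption', 'Transfer', 'Adoption', 'Transfer'],\n",
    "})\n",
    "\n",
    "# Drop every animal that appears more than once in a table (repeat visits)\n",
    "single_intakes = intakes[~intakes['Animal ID'].duplicated(keep=False)]\n",
    "single_outcomes = outcomes[~outcomes['Animal ID'].duplicated(keep=False)]\n",
    "\n",
    "# Inner join on the ID; validate='one_to_one' raises if any ID is still repeated\n",
    "merged = single_intakes.merge(single_outcomes, on='Animal ID', how='inner', validate='one_to_one')\n",
    "print(merged)\n",
    "```\n",
    "\n",
    "A2 is dropped because it has two intake records, and A4 is dropped because it has no intake record, so only A1 and A3 remain.\n",
    "\n",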
    "__Dataset schema:__\n",
    "- __Pet ID__ - Unique ID of the pet\n",
    "- __Outcome Type__ - State of the pet at the time the outcome was recorded (0 = not placed, 1 = placed). This is the field to predict.\n",
    "- __Sex upon Outcome__ - Sex of the pet at outcome\n",
    "- __Name__ - Name of the pet\n",
    "- __Found Location__ - Location where the pet was found before entering the center\n",
    "- __Intake Type__ - Circumstances that brought the pet to the center\n",
    "- __Intake Condition__ - Health condition of the pet when it entered the center\n",
    "- __Pet Type__ - Type of pet\n",
    "- __Sex upon Intake__ - Sex of the pet when it entered the center\n",
    "- __Breed__ - Breed of the pet\n",
    "- __Color__ - Color of the pet\n",
    "- __Age upon Intake Days__ - Age of the pet when it entered the center (days)\n",
    "- __Age upon Outcome Days__ - Age of the pet at outcome (days)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. <a name=\"1\">Set up AutoGluon</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "%pip install -q -r ../requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. <a name=\"2\">Read the dataset</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Let's read the dataset into a pandas DataFrame and split it into training and test sets. We don't create a separate validation set because AutoGluon holds out its own validation data during training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('../data/review/review_dataset.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
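    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 10% of the rows for testing; AutoGluon carves its own validation set out of train_data\n",
    "train_data, test_data = train_test_split(df, test_size=0.1, shuffle=True, random_state=23)"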
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. <a name=\"3\">Train a classifier with AutoGluon</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Running AutoGluon takes only a short snippet: to train models, we call the __.fit()__ method. In this exercise, we pass pandas DataFrames, but AutoGluon can also load data directly from CSV files, as in the snippet below.\n",
    "\n",
    "```python\n",
    "from autogluon.tabular import TabularDataset, TabularPredictor\n",
    "\n",
    "# TabularDataset loads the CSV file at the given path into a DataFrame\n",
    "train_data = TabularDataset('path_to_dataset/train.csv')\n",
    "test_data = TabularDataset('path_to_dataset/test.csv')\n",
    "\n",
    "# 'label_column' is a placeholder for the name of the column to predict\n",
    "predictor = TabularPredictor(label='label_column').fit(train_data)\n",
    "test_predictions = predictor.predict(test_data)\n",
    "```\n",
    "\n",
    "Since we already have separate __data frames__ for training and test data, we use them directly below. For a quick demo, we train on only the first 10000 rows of the training set; to train on all of it, set `k = train_data.shape[0]`."
   ]
  },
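  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `train_test_split` was called with `shuffle=True`, the rows of `train_data` are already in random order, so `train_data.head(k)` returns a random subset rather than the first rows of the CSV file. The sketch below illustrates this on a small frame and shows `DataFrame.sample` as an explicit alternative; the toy data and the value of `k` are hypothetical and chosen only for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hypothetical toy frame standing in for review_dataset.csv\n",
    "df = pd.DataFrame({'feature': range(100), 'Outcome Type': [0, 1] * 50})\n",
    "\n",
    "# Same split settings as in Section 2\n",
    "train_data, test_data = train_test_split(df, test_size=0.1, shuffle=True, random_state=23)\n",
    "\n",
    "k = 10  # hypothetical subset size\n",
    "# The split shuffled the rows, so the first k rows are already a random subset\n",
    "subset_head = train_data.head(k)\n",
    "# Alternative: draw k rows explicitly, with a fixed seed for reproducibility\n",
    "subset_sample = train_data.sample(n=k, random_state=0)\n",
    "print(subset_head.index.tolist())\n",
    "```"
   ]
  },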
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No path specified. Models will be saved in: \"AutogluonModels/ag-20241001_183703/\"\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/core/utils/utils.py:549: FutureWarning: use_inf_as_na option is deprecated and will be removed in a future version. Convert inf values to NaN before operating instead.\n",
      "  with pd.option_context(\"mode.use_inf_as_na\", True):  # treat None, NaN, INF, NINF as NA\n",
      "Beginning AutoGluon training ...\n",
      "AutoGluon will save models to \"AutogluonModels/ag-20241001_183703/\"\n",
      "AutoGluon Version:  0.8.3\n",
      "Python Version:     3.10.14\n",
      "Operating System:   Linux\n",
      "Platform Machine:   x86_64\n",
      "Platform Version:   #1 SMP Tue Sep 10 22:02:55 UTC 2024\n",
      "Disk Space Avail:   10.77 GB / 26.83 GB (40.1%)\n",
      "Train Data Rows:    10000\n",
      "Train Data Columns: 12\n",
      "Label Column: Outcome Type\n",
      "Preprocessing data ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/core/utils/utils.py:549: FutureWarning: use_inf_as_na option is deprecated and will be removed in a future version. Convert inf values to NaN before operating instead.\n",
      "  with pd.option_context(\"mode.use_inf_as_na\", True):  # treat None, NaN, INF, NINF as NA\n",
      "AutoGluon infers your prediction problem is: 'binary' (because only two unique label-values observed).\n",
      "\t2 unique label values:  [1.0, 0.0]\n",
      "\tIf 'binary' is not the correct problem_type, please manually specify the problem_type parameter during predictor init (You may specify problem_type as one of: ['binary', 'multiclass', 'regression'])\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/tabular/learner/default_learner.py:215: FutureWarning: use_inf_as_na option is deprecated and will be removed in a future version. Convert inf values to NaN before operating instead.\n",
      "  with pd.option_context(\"mode.use_inf_as_na\", True):  # treat None, NaN, INF, NINF as NA\n",
      "Selected class <--> label mapping:  class 1 = 1, class 0 = 0\n",
      "Using Feature Generators to preprocess the data ...\n",
      "Fitting AutoMLPipelineFeatureGenerator...\n",
      "\tAvailable Memory:                    12980.59 MB\n",
      "\tTrain Data (Original)  Memory Usage: 6.86 MB (0.1% of available memory)\n",
      "\tInferring data type of each feature based on column values. Set feature_metadata_in to manually specify special dtypes of the features.\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/features/infer_types.py:118: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  result = pd.to_datetime(X, errors=\"coerce\")\n",
      "\tStage 1 Generators:\n",
      "\t\tFitting AsTypeFeatureGenerator...\n",
      "\tStage 2 Generators:\n",
      "\t\tFitting FillNaFeatureGenerator...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/features/generators/fillna.py:58: FutureWarning: The 'downcast' keyword in fillna is deprecated and will be removed in a future version. Use res.infer_objects(copy=False) to infer non-object dtype, or pd.to_numeric with the 'downcast' keyword to downcast numeric results.\n",
      "  X.fillna(self._fillna_feature_map, inplace=True, downcast=False)\n",
      "\tStage 3 Generators:\n",
      "\t\tFitting IdentityFeatureGenerator...\n",
      "\t\tFitting CategoryFeatureGenerator...\n",
      "\t\t\tFitting CategoryMemoryMinimizeFeatureGenerator...\n",
      "\t\tFitting TextSpecialFeatureGenerator...\n",
      "\t\t\tFitting BinnedFeatureGenerator...\n",
      "\t\t\tFitting DropDuplicatesFeatureGenerator...\n",
      "\t\tFitting TextNgramFeatureGenerator...\n",
      "\t\t\tFitting CountVectorizer for text features: ['Found Location']\n",
      "\t\t\tCountVectorizer fit with vocabulary size = 198\n",
      "\tStage 4 Generators:\n",
      "\t\tFitting DropUniqueFeatureGenerator...\n",
      "\tStage 5 Generators:\n",
      "\t\tFitting DropDuplicatesFeatureGenerator...\n",
      "\tUnused Original Features (Count: 1): ['Pet ID']\n",
      "\t\tThese features were not used to generate any of the output features. Add a feature generator compatible with these features to utilize them.\n",
      "\t\tFeatures can also be unused if they carry very little information, such as being categorical but having almost entirely unique values or being duplicates of other features.\n",
      "\t\tThese features do not need to be present at inference time.\n",
      "\t\t('object', []) : 1 | ['Pet ID']\n",
      "\tTypes of features in original data (raw dtype, special dtypes):\n",
      "\t\t('int', [])          : 2 | ['Age upon Intake Days', 'Age upon Outcome Days']\n",
      "\t\t('object', [])       : 8 | ['Sex upon Outcome', 'Name', 'Intake Type', 'Intake Condition', 'Pet Type', ...]\n",
      "\t\t('object', ['text']) : 1 | ['Found Location']\n",
      "\tTypes of features in processed data (raw dtype, special dtypes):\n",
      "\t\t('category', [])                    :   8 | ['Sex upon Outcome', 'Name', 'Intake Type', 'Intake Condition', 'Pet Type', ...]\n",
      "\t\t('category', ['text_as_category'])  :   1 | ['Found Location']\n",
      "\t\t('int', [])                         :   2 | ['Age upon Intake Days', 'Age upon Outcome Days']\n",
      "\t\t('int', ['binned', 'text_special']) :  12 | ['Found Location.char_count', 'Found Location.word_count', 'Found Location.capital_ratio', 'Found Location.lower_ratio', 'Found Location.digit_ratio', ...]\n",
      "\t\t('int', ['text_ngram'])             : 176 | ['__nlp__.183', '__nlp__.183 and', '__nlp__.183 in', '__nlp__.1st', '__nlp__.290', ...]\n",
      "\t10.9s = Fit runtime\n",
      "\t11 features in original data used to generate 199 features in processed data.\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '20637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\tTrain Data (Processed) Memory Usage: 3.94 MB (0.0% of available memory)\n",
      "Data preprocessing and feature engineering runtime = 11.05s ...\n",
      "AutoGluon will gauge predictive performance using evaluation metric: 'accuracy'\n",
      "\tTo change this, specify the eval_metric parameter of Predictor()\n",
      "Automatically generating train/validation split with holdout_frac=0.1, Train Rows: 9000, Val Rows: 1000\n",
      "User-specified model hyperparameters to be fit:\n",
      "{\n",
      "\t'NN_TORCH': {},\n",
      "\t'GBM': [{'extra_trees': True, 'ag_args': {'name_suffix': 'XT'}}, {}, 'GBMLarge'],\n",
      "\t'CAT': {},\n",
      "\t'XGB': {},\n",
      "\t'FASTAI': {},\n",
      "\t'RF': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'XT': [{'criterion': 'gini', 'ag_args': {'name_suffix': 'Gini', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'entropy', 'ag_args': {'name_suffix': 'Entr', 'problem_types': ['binary', 'multiclass']}}, {'criterion': 'squared_error', 'ag_args': {'name_suffix': 'MSE', 'problem_types': ['regression', 'quantile']}}],\n",
      "\t'KNN': [{'weights': 'uniform', 'ag_args': {'name_suffix': 'Unif'}}, {'weights': 'distance', 'ag_args': {'name_suffix': 'Dist'}}],\n",
      "}\n",
      "Fitting 13 L1 models ...\n",
      "Fitting model: KNeighborsUnif ...\n",
      "\t0.653\t = Validation score   (accuracy)\n",
      "\t2.99s\t = Training   runtime\n",
      "\t0.35s\t = Validation runtime\n",
      "Fitting model: KNeighborsDist ...\n",
      "\t0.663\t = Validation score   (accuracy)\n",
      "\t0.05s\t = Training   runtime\n",
      "\t0.15s\t = Validation runtime\n",
      "Fitting model: LightGBMXT ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/dask/dataframe/__init__.py:31: FutureWarning: \n",
      "Dask dataframe query planning is disabled because dask-expr is not installed.\n",
      "\n",
      "You can install it with `pip install dask[dataframe]` or `conda install dask`.\n",
      "This will raise in a future version.\n",
      "\n",
      "  warnings.warn(msg, FutureWarning)\n",
      "\t0.848\t = Validation score   (accuracy)\n",
      "\t4.29s\t = Training   runtime\n",
      "\t0.06s\t = Validation runtime\n",
      "Fitting model: LightGBM ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\t0.853\t = Validation score   (accuracy)\n",
      "\t3.05s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: RandomForestGini ...\n",
      "\t0.853\t = Validation score   (accuracy)\n",
      "\t5.8s\t = Training   runtime\n",
      "\t0.21s\t = Validation runtime\n",
      "Fitting model: RandomForestEntr ...\n",
      "\t0.85\t = Validation score   (accuracy)\n",
      "\t5.78s\t = Training   runtime\n",
      "\t0.22s\t = Validation runtime\n",
      "Fitting model: CatBoost ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\t0.854\t = Validation score   (accuracy)\n",
      "\t17.46s\t = Training   runtime\n",
      "\t0.06s\t = Validation runtime\n",
      "Fitting model: ExtraTreesGini ...\n",
      "\t0.836\t = Validation score   (accuracy)\n",
      "\t5.81s\t = Training   runtime\n",
      "\t0.2s\t = Validation runtime\n",
      "Fitting model: ExtraTreesEntr ...\n",
      "\t0.844\t = Validation score   (accuracy)\n",
      "\t6.1s\t = Training   runtime\n",
      "\t0.19s\t = Validation runtime\n",
      "Fitting model: NeuralNetFastAI ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "No improvement since epoch 7: early stopping\n",
      "\t0.819\t = Validation score   (accuracy)\n",
      "\t34.9s\t = Training   runtime\n",
      "\t0.1s\t = Validation runtime\n",
      "Fitting model: XGBoost ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\t0.854\t = Validation score   (accuracy)\n",
      "\t8.64s\t = Training   runtime\n",
      "\t0.02s\t = Validation runtime\n",
      "Fitting model: NeuralNetTorch ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\t0.853\t = Validation score   (accuracy)\n",
      "\t45.23s\t = Training   runtime\n",
      "\t0.04s\t = Validation runtime\n",
      "Fitting model: LightGBMLarge ...\n",
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/common/utils/pandas_utils.py:50: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise in a future error of pandas. Value '18637.401015228428' has dtype incompatible with int64, please explicitly cast to a compatible dtype first.\n",
      "  memory_usage[column] = (\n",
      "\t0.847\t = Validation score   (accuracy)\n",
      "\t4.97s\t = Training   runtime\n",
      "\t0.03s\t = Validation runtime\n",
      "Fitting model: WeightedEnsemble_L2 ...\n",
      "\t0.871\t = Validation score   (accuracy)\n",
      "\t2.38s\t = Training   runtime\n",
      "\t0.0s\t = Validation runtime\n",
      "AutoGluon training complete, total runtime = 161.57s ... Best model: \"WeightedEnsemble_L2\"\n",
      "TabularPredictor saved. To load, use: predictor = TabularPredictor.load(\"AutogluonModels/ag-20241001_183703/\")\n"
     ]
    }
   ],
   "source": [
    "from autogluon.tabular import TabularDataset, TabularPredictor\n",
    "\n",
    "k = 10000  # use a subset of the training data for a quick demo\n",
    "#k = train_data.shape[0]  # use the whole training set\n",
    "\n",
    "predictor = TabularPredictor(label='Outcome Type').fit(train_data.head(k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also summarize what happened during fit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "*** Summary of fit() ***\n",
      "Estimated performance of each model:\n",
      "                  model  score_val  pred_time_val   fit_time  pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  fit_order\n",
      "0   WeightedEnsemble_L2      0.871       0.625275  92.942297                0.004572           2.378834            2       True         14\n",
      "1               XGBoost      0.854       0.023012   8.635196                0.023012           8.635196            1       True         11\n",
      "2              CatBoost      0.854       0.064999  17.457963                0.064999          17.457963            1       True          7\n",
      "3              LightGBM      0.853       0.025760   3.052832                0.025760           3.052832            1       True          4\n",
      "4        NeuralNetTorch      0.853       0.039637  45.233796                0.039637          45.233796            1       True         12\n",
      "5      RandomForestGini      0.853       0.214309   5.797882                0.214309           5.797882            1       True          5\n",
      "6      RandomForestEntr      0.850       0.218737   5.784430                0.218737           5.784430            1       True          6\n",
      "7            LightGBMXT      0.848       0.061231   4.287037                0.061231           4.287037            1       True          3\n",
      "8         LightGBMLarge      0.847       0.030208   4.965655                0.030208           4.965655            1       True         13\n",
      "9        ExtraTreesEntr      0.844       0.191755   6.098757                0.191755           6.098757            1       True          9\n",
      "10       ExtraTreesGini      0.836       0.204940   5.812799                0.204940           5.812799            1       True          8\n",
      "11      NeuralNetFastAI      0.819       0.097966  34.902269                0.097966          34.902269            1       True         10\n",
      "12       KNeighborsDist      0.663       0.154423   0.054532                0.154423           0.054532            1       True          2\n",
      "13       KNeighborsUnif      0.653       0.347130   2.985137                0.347130           2.985137            1       True          1\n",
      "Number of models trained: 14\n",
      "Types of models trained:\n",
      "{'NNFastAiTabularModel', 'TabularNeuralNetTorchModel', 'WeightedEnsembleModel', 'RFModel', 'XGBoostModel', 'CatBoostModel', 'LGBModel', 'XTModel', 'KNNModel'}\n",
      "Bagging used: False \n",
      "Multi-layer stack-ensembling used: False \n",
      "Feature Metadata (Processed):\n",
      "(raw dtype, special dtypes):\n",
      "('category', [])                    :   8 | ['Sex upon Outcome', 'Name', 'Intake Type', 'Intake Condition', 'Pet Type', ...]\n",
      "('category', ['text_as_category'])  :   1 | ['Found Location']\n",
      "('int', [])                         :   2 | ['Age upon Intake Days', 'Age upon Outcome Days']\n",
      "('int', ['binned', 'text_special']) :  12 | ['Found Location.char_count', 'Found Location.word_count', 'Found Location.capital_ratio', 'Found Location.lower_ratio', 'Found Location.digit_ratio', ...]\n",
      "('int', ['text_ngram'])             : 176 | ['__nlp__.183', '__nlp__.183 and', '__nlp__.183 in', '__nlp__.1st', '__nlp__.290', ...]\n",
      "*** End of fit() summary ***\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/core/utils/plots.py:169: UserWarning: AutoGluon summary plots cannot be created because bokeh is not installed. To see plots, please do: \"pip install bokeh==2.0.1\"\n",
      "  warnings.warn('AutoGluon summary plots cannot be created because bokeh is not installed. To see plots, please do: \"pip install bokeh==2.0.1\"')\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'model_types': {'KNeighborsUnif': 'KNNModel',\n",
       "  'KNeighborsDist': 'KNNModel',\n",
       "  'LightGBMXT': 'LGBModel',\n",
       "  'LightGBM': 'LGBModel',\n",
       "  'RandomForestGini': 'RFModel',\n",
       "  'RandomForestEntr': 'RFModel',\n",
       "  'CatBoost': 'CatBoostModel',\n",
       "  'ExtraTreesGini': 'XTModel',\n",
       "  'ExtraTreesEntr': 'XTModel',\n",
       "  'NeuralNetFastAI': 'NNFastAiTabularModel',\n",
       "  'XGBoost': 'XGBoostModel',\n",
       "  'NeuralNetTorch': 'TabularNeuralNetTorchModel',\n",
       "  'LightGBMLarge': 'LGBModel',\n",
       "  'WeightedEnsemble_L2': 'WeightedEnsembleModel'},\n",
       " 'model_performance': {'KNeighborsUnif': 0.653,\n",
       "  'KNeighborsDist': 0.663,\n",
       "  'LightGBMXT': 0.848,\n",
       "  'LightGBM': 0.853,\n",
       "  'RandomForestGini': 0.853,\n",
       "  'RandomForestEntr': 0.85,\n",
       "  'CatBoost': 0.854,\n",
       "  'ExtraTreesGini': 0.836,\n",
       "  'ExtraTreesEntr': 0.844,\n",
       "  'NeuralNetFastAI': 0.819,\n",
       "  'XGBoost': 0.854,\n",
       "  'NeuralNetTorch': 0.853,\n",
       "  'LightGBMLarge': 0.847,\n",
       "  'WeightedEnsemble_L2': 0.871},\n",
       " 'model_best': 'WeightedEnsemble_L2',\n",
       " 'model_paths': {'KNeighborsUnif': 'AutogluonModels/ag-20241001_183703/models/KNeighborsUnif/',\n",
       "  'KNeighborsDist': 'AutogluonModels/ag-20241001_183703/models/KNeighborsDist/',\n",
       "  'LightGBMXT': 'AutogluonModels/ag-20241001_183703/models/LightGBMXT/',\n",
       "  'LightGBM': 'AutogluonModels/ag-20241001_183703/models/LightGBM/',\n",
       "  'RandomForestGini': 'AutogluonModels/ag-20241001_183703/models/RandomForestGini/',\n",
       "  'RandomForestEntr': 'AutogluonModels/ag-20241001_183703/models/RandomForestEntr/',\n",
       "  'CatBoost': 'AutogluonModels/ag-20241001_183703/models/CatBoost/',\n",
       "  'ExtraTreesGini': 'AutogluonModels/ag-20241001_183703/models/ExtraTreesGini/',\n",
       "  'ExtraTreesEntr': 'AutogluonModels/ag-20241001_183703/models/ExtraTreesEntr/',\n",
       "  'NeuralNetFastAI': 'AutogluonModels/ag-20241001_183703/models/NeuralNetFastAI/',\n",
       "  'XGBoost': 'AutogluonModels/ag-20241001_183703/models/XGBoost/',\n",
       "  'NeuralNetTorch': 'AutogluonModels/ag-20241001_183703/models/NeuralNetTorch/',\n",
       "  'LightGBMLarge': 'AutogluonModels/ag-20241001_183703/models/LightGBMLarge/',\n",
       "  'WeightedEnsemble_L2': 'AutogluonModels/ag-20241001_183703/models/WeightedEnsemble_L2/'},\n",
       " 'model_fit_times': {'KNeighborsUnif': 2.9851365089416504,\n",
       "  'KNeighborsDist': 0.05453157424926758,\n",
       "  'LightGBMXT': 4.287037134170532,\n",
       "  'LightGBM': 3.0528316497802734,\n",
       "  'RandomForestGini': 5.797881841659546,\n",
       "  'RandomForestEntr': 5.784429550170898,\n",
       "  'CatBoost': 17.45796298980713,\n",
       "  'ExtraTreesGini': 5.812798976898193,\n",
       "  'ExtraTreesEntr': 6.098757028579712,\n",
       "  'NeuralNetFastAI': 34.90226936340332,\n",
       "  'XGBoost': 8.6351957321167,\n",
       "  'NeuralNetTorch': 45.23379588127136,\n",
       "  'LightGBMLarge': 4.965655326843262,\n",
       "  'WeightedEnsemble_L2': 2.3788342475891113},\n",
       " 'model_pred_times': {'KNeighborsUnif': 0.34712982177734375,\n",
       "  'KNeighborsDist': 0.15442276000976562,\n",
       "  'LightGBMXT': 0.06123089790344238,\n",
       "  'LightGBM': 0.02575993537902832,\n",
       "  'RandomForestGini': 0.2143092155456543,\n",
       "  'RandomForestEntr': 0.21873688697814941,\n",
       "  'CatBoost': 0.06499910354614258,\n",
       "  'ExtraTreesGini': 0.204939603805542,\n",
       "  'ExtraTreesEntr': 0.1917552947998047,\n",
       "  'NeuralNetFastAI': 0.09796595573425293,\n",
       "  'XGBoost': 0.023012399673461914,\n",
       "  'NeuralNetTorch': 0.03963661193847656,\n",
       "  'LightGBMLarge': 0.030208110809326172,\n",
       "  'WeightedEnsemble_L2': 0.0045719146728515625},\n",
       " 'num_bag_folds': 0,\n",
       " 'max_stack_level': 2,\n",
       " 'num_classes': 2,\n",
       " 'model_hyperparams': {'KNeighborsUnif': {'weights': 'uniform'},\n",
       "  'KNeighborsDist': {'weights': 'distance'},\n",
       "  'LightGBMXT': {'learning_rate': 0.05, 'extra_trees': True},\n",
       "  'LightGBM': {'learning_rate': 0.05},\n",
       "  'RandomForestGini': {'n_estimators': 300,\n",
       "   'max_leaf_nodes': 15000,\n",
       "   'n_jobs': -1,\n",
       "   'random_state': 0,\n",
       "   'bootstrap': True,\n",
       "   'criterion': 'gini'},\n",
       "  'RandomForestEntr': {'n_estimators': 300,\n",
       "   'max_leaf_nodes': 15000,\n",
       "   'n_jobs': -1,\n",
       "   'random_state': 0,\n",
       "   'bootstrap': True,\n",
       "   'criterion': 'entropy'},\n",
       "  'CatBoost': {'iterations': 10000,\n",
       "   'learning_rate': 0.05,\n",
       "   'random_seed': 0,\n",
       "   'allow_writing_files': False,\n",
       "   'eval_metric': 'Accuracy'},\n",
       "  'ExtraTreesGini': {'n_estimators': 300,\n",
       "   'max_leaf_nodes': 15000,\n",
       "   'n_jobs': -1,\n",
       "   'random_state': 0,\n",
       "   'bootstrap': True,\n",
       "   'criterion': 'gini'},\n",
       "  'ExtraTreesEntr': {'n_estimators': 300,\n",
       "   'max_leaf_nodes': 15000,\n",
       "   'n_jobs': -1,\n",
       "   'random_state': 0,\n",
       "   'bootstrap': True,\n",
       "   'criterion': 'entropy'},\n",
       "  'NeuralNetFastAI': {'layers': None,\n",
       "   'emb_drop': 0.1,\n",
       "   'ps': 0.1,\n",
       "   'bs': 'auto',\n",
       "   'lr': 0.01,\n",
       "   'epochs': 'auto',\n",
       "   'early.stopping.min_delta': 0.0001,\n",
       "   'early.stopping.patience': 20,\n",
       "   'smoothing': 0.0},\n",
       "  'XGBoost': {'n_estimators': 10000,\n",
       "   'learning_rate': 0.1,\n",
       "   'n_jobs': -1,\n",
       "   'proc.max_category_levels': 100,\n",
       "   'objective': 'binary:logistic',\n",
       "   'booster': 'gbtree'},\n",
       "  'NeuralNetTorch': {'num_epochs': 500,\n",
       "   'epochs_wo_improve': 20,\n",
       "   'activation': 'relu',\n",
       "   'embedding_size_factor': 1.0,\n",
       "   'embed_exponent': 0.56,\n",
       "   'max_embedding_dim': 100,\n",
       "   'y_range': None,\n",
       "   'y_range_extend': 0.05,\n",
       "   'dropout_prob': 0.1,\n",
       "   'optimizer': 'adam',\n",
       "   'learning_rate': 0.0003,\n",
       "   'weight_decay': 1e-06,\n",
       "   'proc.embed_min_categories': 4,\n",
       "   'proc.impute_strategy': 'median',\n",
       "   'proc.max_category_levels': 100,\n",
       "   'proc.skew_threshold': 0.99,\n",
       "   'use_ngram_features': False,\n",
       "   'num_layers': 4,\n",
       "   'hidden_size': 128,\n",
       "   'max_batch_size': 512,\n",
       "   'use_batchnorm': False,\n",
       "   'loss_function': 'auto'},\n",
       "  'LightGBMLarge': {'learning_rate': 0.03,\n",
       "   'num_leaves': 128,\n",
       "   'feature_fraction': 0.9,\n",
       "   'min_data_in_leaf': 5},\n",
       "  'WeightedEnsemble_L2': {'use_orig_features': False,\n",
       "   'max_base_models': 25,\n",
       "   'max_base_models_per_type': 5,\n",
       "   'save_bag_folds': True}},\n",
       " 'leaderboard':                   model  score_val  pred_time_val   fit_time  \\\n",
       " 0   WeightedEnsemble_L2      0.871       0.625275  92.942297   \n",
       " 1               XGBoost      0.854       0.023012   8.635196   \n",
       " 2              CatBoost      0.854       0.064999  17.457963   \n",
       " 3              LightGBM      0.853       0.025760   3.052832   \n",
       " 4        NeuralNetTorch      0.853       0.039637  45.233796   \n",
       " 5      RandomForestGini      0.853       0.214309   5.797882   \n",
       " 6      RandomForestEntr      0.850       0.218737   5.784430   \n",
       " 7            LightGBMXT      0.848       0.061231   4.287037   \n",
       " 8         LightGBMLarge      0.847       0.030208   4.965655   \n",
       " 9        ExtraTreesEntr      0.844       0.191755   6.098757   \n",
       " 10       ExtraTreesGini      0.836       0.204940   5.812799   \n",
       " 11      NeuralNetFastAI      0.819       0.097966  34.902269   \n",
       " 12       KNeighborsDist      0.663       0.154423   0.054532   \n",
       " 13       KNeighborsUnif      0.653       0.347130   2.985137   \n",
       " \n",
       "     pred_time_val_marginal  fit_time_marginal  stack_level  can_infer  \\\n",
       " 0                 0.004572           2.378834            2       True   \n",
       " 1                 0.023012           8.635196            1       True   \n",
       " 2                 0.064999          17.457963            1       True   \n",
       " 3                 0.025760           3.052832            1       True   \n",
       " 4                 0.039637          45.233796            1       True   \n",
       " 5                 0.214309           5.797882            1       True   \n",
       " 6                 0.218737           5.784430            1       True   \n",
       " 7                 0.061231           4.287037            1       True   \n",
       " 8                 0.030208           4.965655            1       True   \n",
       " 9                 0.191755           6.098757            1       True   \n",
       " 10                0.204940           5.812799            1       True   \n",
       " 11                0.097966          34.902269            1       True   \n",
       " 12                0.154423           0.054532            1       True   \n",
       " 13                0.347130           2.985137            1       True   \n",
       " \n",
       "     fit_order  \n",
       " 0          14  \n",
       " 1          11  \n",
       " 2           7  \n",
       " 3           4  \n",
       " 4          12  \n",
       " 5           5  \n",
       " 6           6  \n",
       " 7           3  \n",
       " 8          13  \n",
       " 9           9  \n",
       " 10          8  \n",
       " 11         10  \n",
       " 12          2  \n",
       " 13          1  }"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.fit_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. <a name=\"4\">Model evaluation</a>\n",
    "(<a href=\"#0\">Go to top</a>)\n",
    "\n",
    "Next, we use the separate test dataset to demonstrate how to make predictions on new examples at inference time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/features/generators/fillna.py:58: FutureWarning: The 'downcast' keyword in fillna is deprecated and will be removed in a future version. Use res.infer_objects(copy=False) to infer non-object dtype, or pd.to_numeric with the 'downcast' keyword to downcast numeric results.\n",
      "  X.fillna(self._fillna_feature_map, inplace=True, downcast=False)\n",
      "Evaluation: accuracy on test data: 0.8593570007330611\n",
      "Evaluations on test data:\n",
      "{\n",
      "    \"accuracy\": 0.8593570007330611,\n",
      "    \"balanced_accuracy\": 0.846874357214327,\n",
      "    \"mcc\": 0.7158381734029869,\n",
      "    \"f1\": 0.8834504903237004,\n",
      "    \"precision\": 0.8336062888961677,\n",
      "    \"recall\": 0.9396344840317519\n",
      "}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'accuracy': 0.8593570007330611,\n",
       " 'balanced_accuracy': 0.846874357214327,\n",
       " 'mcc': 0.7158381734029869,\n",
       " 'f1': 0.8834504903237004,\n",
       " 'precision': 0.8336062888961677,\n",
       " 'recall': 0.9396344840317519}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
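    "# First, predict the outcome for every example in the test set\n",
    "# (predictions must cover the same rows as the labels they are compared to)\n",
    "y_pred = predictor.predict(test_data)\n",
    "\n",
    "# Then, compare the predictions with the true labels\n",
    "predictor.evaluate_predictions(y_true=test_data['Outcome Type'],\n",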
    "                               y_pred=y_pred,\n",
    "                               auxiliary_metrics=True)"
   ]
  },
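  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The auxiliary metrics reported by `evaluate_predictions` are standard functions of the binary confusion matrix (true/false positives and negatives). As a rough illustration, the sketch below recomputes them in plain Python. It assumes the positive class is encoded as `1`; the helper `binary_metrics` and the toy labels are hypothetical and not part of AutoGluon or this dataset, and we assume AutoGluon reports these usual definitions (the same ones scikit-learn uses).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def binary_metrics(y_true, y_pred, positive=1):\n",
    "    # Hypothetical helper: count the four cells of the confusion matrix\n",
    "    pairs = list(zip(y_true, y_pred))\n",
    "    tp = sum(1 for t, p in pairs if t == positive and p == positive)\n",
    "    tn = sum(1 for t, p in pairs if t != positive and p != positive)\n",
    "    fp = sum(1 for t, p in pairs if t != positive and p == positive)\n",
    "    fn = sum(1 for t, p in pairs if t == positive and p != positive)\n",
    "\n",
    "    accuracy = (tp + tn) / len(pairs)\n",
    "    precision = tp / (tp + fp) if tp + fp else 0.0\n",
    "    recall = tp / (tp + fn) if tp + fn else 0.0  # true positive rate\n",
    "    specificity = tn / (tn + fp) if tn + fp else 0.0  # true negative rate\n",
    "    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0\n",
    "    # Balanced accuracy averages the recall of the two classes\n",
    "    balanced_accuracy = (recall + specificity) / 2\n",
    "    # Matthews correlation coefficient, set to 0 here when a margin is empty\n",
    "    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))\n",
    "    mcc = (tp * tn - fp * fn) / denom if denom else 0.0\n",
    "    return {\n",
    "        'accuracy': accuracy,\n",
    "        'balanced_accuracy': balanced_accuracy,\n",
    "        'mcc': mcc,\n",
    "        'f1': f1,\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "    }\n",
    "\n",
    "\n",
    "# Toy labels, made up for illustration: tp=3, tn=3, fp=1, fn=1\n",
    "metrics = binary_metrics([1, 1, 1, 0, 0, 0, 1, 0], [1, 1, 0, 0, 1, 0, 1, 0])\n",
    "print(metrics['accuracy'], metrics['mcc'])  # 0.75 0.5\n",
    "```\n",
    "\n",
    "Reading the test results above through these formulas: since balanced accuracy = (recall + specificity) / 2, a balanced accuracy of about 0.847 with a recall of about 0.940 implies a specificity of roughly 0.75. The model is therefore noticeably better at recognizing the positive class than the negative one.\n"
   ]
  },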
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the performance of each individual trained model on the test data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/envs/sagemaker-distribution/lib/python3.10/site-packages/autogluon/features/generators/fillna.py:58: FutureWarning: The 'downcast' keyword in fillna is deprecated and will be removed in a future version. Use res.infer_objects(copy=False) to infer non-object dtype, or pd.to_numeric with the 'downcast' keyword to downcast numeric results.\n",
      "  X.fillna(self._fillna_feature_map, inplace=True, downcast=False)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>model</th>\n",
       "      <th>score_test</th>\n",
       "      <th>score_val</th>\n",
       "      <th>pred_time_test</th>\n",
       "      <th>pred_time_val</th>\n",
       "      <th>fit_time</th>\n",
       "      <th>pred_time_test_marginal</th>\n",
       "      <th>pred_time_val_marginal</th>\n",
       "      <th>fit_time_marginal</th>\n",
       "      <th>stack_level</th>\n",
       "      <th>can_infer</th>\n",
       "      <th>fit_order</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>WeightedEnsemble_L2</td>\n",
       "      <td>0.859357</td>\n",
       "      <td>0.871</td>\n",
       "      <td>3.335518</td>\n",
       "      <td>0.625275</td>\n",
       "      <td>92.942297</td>\n",
       "      <td>0.006877</td>\n",
       "      <td>0.004572</td>\n",
       "      <td>2.378834</td>\n",
       "      <td>2</td>\n",
       "      <td>True</td>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>RandomForestEntr</td>\n",
       "      <td>0.855482</td>\n",
       "      <td>0.850</td>\n",
       "      <td>1.125673</td>\n",
       "      <td>0.218737</td>\n",
       "      <td>5.784430</td>\n",
       "      <td>1.125673</td>\n",
       "      <td>0.218737</td>\n",
       "      <td>5.784430</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CatBoost</td>\n",
       "      <td>0.854016</td>\n",
       "      <td>0.854</td>\n",
       "      <td>0.107882</td>\n",
       "      <td>0.064999</td>\n",
       "      <td>17.457963</td>\n",
       "      <td>0.107882</td>\n",
       "      <td>0.064999</td>\n",
       "      <td>17.457963</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>RandomForestGini</td>\n",
       "      <td>0.854016</td>\n",
       "      <td>0.853</td>\n",
       "      <td>0.954611</td>\n",
       "      <td>0.214309</td>\n",
       "      <td>5.797882</td>\n",
       "      <td>0.954611</td>\n",
       "      <td>0.214309</td>\n",
       "      <td>5.797882</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>LightGBM</td>\n",
       "      <td>0.850141</td>\n",
       "      <td>0.853</td>\n",
       "      <td>0.200560</td>\n",
       "      <td>0.025760</td>\n",
       "      <td>3.052832</td>\n",
       "      <td>0.200560</td>\n",
       "      <td>0.025760</td>\n",
       "      <td>3.052832</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>XGBoost</td>\n",
       "      <td>0.849618</td>\n",
       "      <td>0.854</td>\n",
       "      <td>0.175631</td>\n",
       "      <td>0.023012</td>\n",
       "      <td>8.635196</td>\n",
       "      <td>0.175631</td>\n",
       "      <td>0.023012</td>\n",
       "      <td>8.635196</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>NeuralNetTorch</td>\n",
       "      <td>0.846895</td>\n",
       "      <td>0.853</td>\n",
       "      <td>0.158293</td>\n",
       "      <td>0.039637</td>\n",
       "      <td>45.233796</td>\n",
       "      <td>0.158293</td>\n",
       "      <td>0.039637</td>\n",
       "      <td>45.233796</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>LightGBMLarge</td>\n",
       "      <td>0.846686</td>\n",
       "      <td>0.847</td>\n",
       "      <td>0.137499</td>\n",
       "      <td>0.030208</td>\n",
       "      <td>4.965655</td>\n",
       "      <td>0.137499</td>\n",
       "      <td>0.030208</td>\n",
       "      <td>4.965655</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>LightGBMXT</td>\n",
       "      <td>0.846686</td>\n",
       "      <td>0.848</td>\n",
       "      <td>0.347909</td>\n",
       "      <td>0.061231</td>\n",
       "      <td>4.287037</td>\n",
       "      <td>0.347909</td>\n",
       "      <td>0.061231</td>\n",
       "      <td>4.287037</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>ExtraTreesGini</td>\n",
       "      <td>0.842811</td>\n",
       "      <td>0.836</td>\n",
       "      <td>1.077193</td>\n",
       "      <td>0.204940</td>\n",
       "      <td>5.812799</td>\n",
       "      <td>1.077193</td>\n",
       "      <td>0.204940</td>\n",
       "      <td>5.812799</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>ExtraTreesEntr</td>\n",
       "      <td>0.841030</td>\n",
       "      <td>0.844</td>\n",
       "      <td>1.383755</td>\n",
       "      <td>0.191755</td>\n",
       "      <td>6.098757</td>\n",
       "      <td>1.383755</td>\n",
       "      <td>0.191755</td>\n",
       "      <td>6.098757</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>NeuralNetFastAI</td>\n",
       "      <td>0.827416</td>\n",
       "      <td>0.819</td>\n",
       "      <td>0.574561</td>\n",
       "      <td>0.097966</td>\n",
       "      <td>34.902269</td>\n",
       "      <td>0.574561</td>\n",
       "      <td>0.097966</td>\n",
       "      <td>34.902269</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>10</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>KNeighborsDist</td>\n",
       "      <td>0.651168</td>\n",
       "      <td>0.663</td>\n",
       "      <td>1.443214</td>\n",
       "      <td>0.154423</td>\n",
       "      <td>0.054532</td>\n",
       "      <td>1.443214</td>\n",
       "      <td>0.154423</td>\n",
       "      <td>0.054532</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>KNeighborsUnif</td>\n",
       "      <td>0.648654</td>\n",
       "      <td>0.653</td>\n",
       "      <td>1.576504</td>\n",
       "      <td>0.347130</td>\n",
       "      <td>2.985137</td>\n",
       "      <td>1.576504</td>\n",
       "      <td>0.347130</td>\n",
       "      <td>2.985137</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  model  score_test  score_val  pred_time_test  pred_time_val  \\\n",
       "0   WeightedEnsemble_L2    0.859357      0.871        3.335518       0.625275   \n",
       "1      RandomForestEntr    0.855482      0.850        1.125673       0.218737   \n",
       "2              CatBoost    0.854016      0.854        0.107882       0.064999   \n",
       "3      RandomForestGini    0.854016      0.853        0.954611       0.214309   \n",
       "4              LightGBM    0.850141      0.853        0.200560       0.025760   \n",
       "5               XGBoost    0.849618      0.854        0.175631       0.023012   \n",
       "6        NeuralNetTorch    0.846895      0.853        0.158293       0.039637   \n",
       "7         LightGBMLarge    0.846686      0.847        0.137499       0.030208   \n",
       "8            LightGBMXT    0.846686      0.848        0.347909       0.061231   \n",
       "9        ExtraTreesGini    0.842811      0.836        1.077193       0.204940   \n",
       "10       ExtraTreesEntr    0.841030      0.844        1.383755       0.191755   \n",
       "11      NeuralNetFastAI    0.827416      0.819        0.574561       0.097966   \n",
       "12       KNeighborsDist    0.651168      0.663        1.443214       0.154423   \n",
       "13       KNeighborsUnif    0.648654      0.653        1.576504       0.347130   \n",
       "\n",
       "     fit_time  pred_time_test_marginal  pred_time_val_marginal  \\\n",
       "0   92.942297                 0.006877                0.004572   \n",
       "1    5.784430                 1.125673                0.218737   \n",
       "2   17.457963                 0.107882                0.064999   \n",
       "3    5.797882                 0.954611                0.214309   \n",
       "4    3.052832                 0.200560                0.025760   \n",
       "5    8.635196                 0.175631                0.023012   \n",
       "6   45.233796                 0.158293                0.039637   \n",
       "7    4.965655                 0.137499                0.030208   \n",
       "8    4.287037                 0.347909                0.061231   \n",
       "9    5.812799                 1.077193                0.204940   \n",
       "10   6.098757                 1.383755                0.191755   \n",
       "11  34.902269                 0.574561                0.097966   \n",
       "12   0.054532                 1.443214                0.154423   \n",
       "13   2.985137                 1.576504                0.347130   \n",
       "\n",
       "    fit_time_marginal  stack_level  can_infer  fit_order  \n",
       "0            2.378834            2       True         14  \n",
       "1            5.784430            1       True          6  \n",
       "2           17.457963            1       True          7  \n",
       "3            5.797882            1       True          5  \n",
       "4            3.052832            1       True          4  \n",
       "5            8.635196            1       True         11  \n",
       "6           45.233796            1       True         12  \n",
       "7            4.965655            1       True         13  \n",
       "8            4.287037            1       True          3  \n",
       "9            5.812799            1       True          8  \n",
       "10           6.098757            1       True          9  \n",
       "11          34.902269            1       True         10  \n",
       "12           0.054532            1       True          2  \n",
       "13           2.985137            1       True          1  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictor.leaderboard(test_data, silent=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. <a name=\"5\">Clean up model artifacts</a>\n",
    "(<a href=\"#0\">Go to top</a>)"
   ]
  },
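  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `!rm -r AutogluonModels` command relies on a Unix-like shell being available to the notebook. A portable alternative is sketched below using Python's standard `shutil.rmtree`; the helper name `remove_model_dir` is hypothetical, and `AutogluonModels` is the folder AutoGluon wrote the models to (see the model paths in the fit summary above).\n",
    "\n",
    "```python\n",
    "import shutil\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def remove_model_dir(path='AutogluonModels'):\n",
    "    # Hypothetical helper: delete the folder holding the trained models\n",
    "    target = Path(path)\n",
    "    if not target.exists():\n",
    "        return False  # nothing to clean up\n",
    "    shutil.rmtree(target)  # removes the directory and everything inside it\n",
    "    return True\n",
    "```\n",
    "\n",
    "Unlike `rm -r`, calling `remove_model_dir()` simply returns `False` if the folder is already gone. Either way, the saved models are deleted permanently, so the predictor can no longer be reloaded from disk afterwards.\n"
   ]
  },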
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!rm -r AutogluonModels"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sagemaker-distribution:Python",
   "language": "python",
   "name": "conda-env-sagemaker-distribution-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
